#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pickle
from vedo import Sphere, Plotter, Text2D, settings
import numpy as np
from scipy.spatial import distance

from fugw.fugw import FUGW

# %%
# White sphere mesh (r=1) whose vertices carry the source and target signals
sphere_mesh = Sphere(r=1, res=40).color("white")

# Mesh points as a NumPy array; n is the length of its first axis
vertices_sphere = np.array(sphere_mesh.points())
n = len(vertices_sphere)

# %%
# Pairwise cosine distances between all vertices. The same matrix is passed
# twice to each solver's fit() call further down.
C = distance.cdist(vertices_sphere, vertices_sphere, "cosine")


# %%
# 1 - Define source and target signals over spheres
def pdf_vmf(x, mu, kappa):
    """Return the unnormalised von Mises-Fisher-style value exp(kappa * <mu, x>).

    No normalising constant is applied; callers in this script rescale the
    result by its maximum instead.

    Parameters
    ----------
    x : array whose transpose is dotted with ``mu``; in this script it is
        the array of sphere vertices (one point per row).
    mu : direction around which the bump is centred.
    kappa : multiplier applied to the dot product before exponentiation
        (larger values give a narrower bump).
    """
    alignment = np.dot(mu, x.T)
    return np.exp(kappa * alignment)


# Unit-length centres of the three bumps
c1 = np.array([0, 1, 0])
c2 = np.array([0, 1, 1]) / np.sqrt(2)
c3 = np.array([0.5, 1, 0])
c3 /= np.linalg.norm(c3)

# Each bump is rescaled so that its maximum over the vertices equals 1
f1 = pdf_vmf(vertices_sphere, c1, 20)
f1 /= f1.max()
f2 = pdf_vmf(vertices_sphere, c2, 140)
f2 /= f2.max()

# Source signal (two von Mises): wide bump at c1 plus a narrower one at c2
f = f1 + 0.9 * f2

# Target signal (one von Mises): single wide bump at c3
f3 = pdf_vmf(vertices_sphere, c3, 20)
f3 /= f3.max()


# %%
# 2 - Compute FUGW alignments
# (either train or load it)


# %%
# 2.1 - BALANCED CASE
solver_balanced = FUGW()

# %%
# Load model...
# with open("./balanced_model.pkl", "rb") as model_file:
#     solver_balanced = pickle.load(model_file)


# %%
# ...or train and save it
# rho=100 is what the plots below label as the balancing constraint for this
# solver (meaning of the remaining tuple entries: see the FUGW API — TODO
# confirm).
solver_balanced.alpha = 0.5
solver_balanced.rho = (100, 100, 0, 0)
# Signals are passed as single-row 2D arrays; C is used for both geometries
solver_balanced.fit(f[np.newaxis, :], f3[np.newaxis, :], C, C)

# Sums of the fitted plan along each axis, as percent deviation of
# n * sum from 1. marg1 is plotted as "Mass change on source" and marg2 as
# "Mass change on target".
marg1 = 100 * (n * np.sum(solver_balanced.pi, axis=1) - 1)
marg2 = 100 * (n * np.sum(solver_balanced.pi, axis=0) - 1)
f_transported1 = solver_balanced.transform(f)

# Persist the fitted solver so the "Load model" cell above can reuse it
with open("./balanced_model.pkl", "wb") as model_file:
    pickle.dump(solver_balanced, model_file)


# %%
# 2.2 - UNBALANCED CASE
solver_unbalanced = FUGW()


# %%
# Load model...
# with open("./unbalanced_model.pkl", "rb") as model_file:
#     solver_unbalanced = pickle.load(model_file)

# %%
# ...or train and save it
# NOTE(review): alpha is not assigned here, whereas the balanced solver sets
# alpha = 0.5 — presumably the FUGW() default is intended; confirm.
solver_unbalanced.rho = (1, 1, 0, 0)
# Signals are passed as single-row 2D arrays; C is used for both geometries
solver_unbalanced.fit(f[np.newaxis, :], f3[np.newaxis, :], C, C)

# Sums of the fitted plan along each axis, as percent deviation of
# n * sum from 1. marg3 is plotted as "Mass change on source" and marg4 as
# "Mass change on target".
marg3 = 100 * (n * np.sum(solver_unbalanced.pi, axis=1) - 1)
marg4 = 100 * (n * np.sum(solver_unbalanced.pi, axis=0) - 1)
f_transported2 = solver_unbalanced.transform(f)

# Persist the fitted solver so the "Load model" cell above can reuse it
with open("./unbalanced_model.pkl", "wb") as model_file:
    pickle.dump(solver_unbalanced, model_file)


# %%
# 3 - Visualise results
settings.immediateRendering = False  # faster for multi-renderers


def _show_mass_change(plotter, values, title, at, limit, **show_kwargs):
    """Draw one diverging-colour "% mass change" sphere in renderer `at`.

    Clones the shared `sphere_mesh`, colours it with `values` on the
    "RdBu_r" map clipped to [-limit, limit], attaches a 3D scalar bar and
    shows it in `plotter` together with a `title` caption. Extra keyword
    arguments are forwarded to `plotter.show`. Returns the coloured clone.
    """
    mesh = sphere_mesh.clone(deep=False).cmap(
        "RdBu_r", values, vmin=-limit, vmax=limit
    )  # alternative 2D bar: .addScalarBar(tformat='%.1f')
    mesh.addScalarBar3D(
        title="% mass change", c="k", labelSize=2, titleSize=2
    )
    mesh.scalarbar.scale(0.8).x(1.3).z(0)
    caption = Text2D(title, font="Times", s=2)
    plotter.show(mesh, caption, at=at, **show_kwargs)
    return mesh


# %%
# 2 x 4 grid of renderers, addressed below through at=0..7
vp1 = Plotter(shape=(2, 4), axes=4)
# Rotate the shared mesh once, before cloning, so every panel has the same
# orientation
sphere_mesh.rotateZ(20).rotateX(60)

# Colour range (in %) of the mass-change panels
offset_mass = 20

# Source and target signals share a fixed [0, 1] colour range
sph1 = sphere_mesh.clone(deep=False).cmap("plasma", f, vmin=0, vmax=1)
legend = Text2D("Source", font="Times", s=2)
vp1.show(sph1, legend, at=0, zoom="tight")

sph2 = sphere_mesh.clone(deep=False).cmap("plasma", f3, vmin=0, vmax=1)
legend = Text2D("Target", font="Times", s=2)
vp1.show(sph2, legend, at=4)

# Balanced solver: transported signal, then mass change on each side.
# NOTE(review): "\r" in the caption below is a carriage-return escape, not a
# literal backslash — presumably relied on by vedo's symbol substitution to
# render a rho; confirm before switching to a raw string.
sph3 = sphere_mesh.clone(deep=False).cmap("plasma", f_transported1)
legend = Text2D("Balancing constraint \rho=100", font="Times", s=2)
vp1.show(sph3, legend, at=1)

sph4 = _show_mass_change(
    vp1, marg1, "Mass change on source", at=2, limit=offset_mass
)
sph5 = _show_mass_change(
    vp1, marg2, "Mass change on target", at=3, limit=offset_mass
)

# Unbalanced solver: same three panels (same "\r" caveat as above)
sph6 = sphere_mesh.clone(deep=False).cmap("plasma", f_transported2)
legend = Text2D("Balancing constraint \rho=1", font="Times", s=2)
vp1.show(sph6, legend, at=5)

sph7 = _show_mass_change(
    vp1, marg3, "Mass change on source", at=6, limit=offset_mass
)
# The final panel also passes interactive=1 to show()
sph8 = _show_mass_change(
    vp1, marg4, "Mass change on target", at=7, limit=offset_mass, interactive=1
)

# %%
vp1.render()
